package com.jstarcraft.ai.jsat.math.optimization;

import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.math.Function;
import com.jstarcraft.ai.jsat.math.FunctionVec;

/**
 * An implementation of Backtracking line search using the Armijo rule. The
 * search for alpha is done by quadratic and cubic interpolation without using
 * any derivative evaluations.
 * 
 * @author Edward Raff
 */
public class BacktrackingArmijoLineSearch implements LineSearch {
    /** Once the trial step drops below this value the search gives up. */
    private static final double MIN_ALPHA = 1e-20;

    /** Multiplicative shrink factor used when an interpolated step is rejected. */
    private final double rho;
    /** Constant of the sufficient decrease (Armijo) condition. */
    private double c1;

    /**
     * Creates a new Backtracking line search with {@code rho = 0.5} and
     * {@code c1 = 0.1}.
     */
    public BacktrackingArmijoLineSearch() {
        this(0.5, 1e-1);
    }

    /**
     * Creates a new Backtracking line search object
     * 
     * @param rho constant to decrease alpha by in (0, 1) when interpolation is not
     *            possible
     * @param c1  the <i>sufficient decrease condition</i> constant in (0, 1/2)
     * @throws IllegalArgumentException if either argument is outside its open
     *                                  interval (NaN included)
     */
    public BacktrackingArmijoLineSearch(double rho, double c1) {
        if (!(rho > 0 && rho < 1))
            throw new IllegalArgumentException("rho must be in (0,1), not " + rho);
        this.rho = rho;
        setC1(c1);
    }

    /**
     * Sets the constant used for the <i>sufficient decrease condition</i>
     * f(x+&alpha; p) &le; f(x) + c<sub>1</sub> &alpha; p<sup>T</sup>&nabla;f(x)
     * 
     * @param c1 the <i>sufficient decrease condition</i> constant in (0, 1/2)
     * @throws IllegalArgumentException if {@code c1 <= 0} or {@code c1 >= 0.5}
     */
    public void setC1(double c1) {
        if (c1 <= 0 || c1 >= 0.5)
            throw new IllegalArgumentException("c1 must be in (0, 1/2) not " + c1);
        this.c1 = c1;
    }

    /**
     * Returns the <i>sufficient decrease condition</i> constant
     * 
     * @return the <i>sufficient decrease condition</i> constant
     */
    public double getC1() {
        return c1;
    }

    /**
     * Searches for a step length &alpha; along {@code p_k} that satisfies the
     * sufficient decrease condition, starting from {@code alpha_max} and
     * shrinking the step by safeguarded quadratic (first backtrack) and cubic
     * (later backtracks) interpolation. Only function values are evaluated
     * inside the loop; no additional gradients are computed.
     * 
     * @param alpha_max       the first step length that is tried
     * @param x_k             the current point
     * @param x_grad          the gradient at {@code x_k}; only read when
     *                        {@code gradP} is NaN
     * @param p_k             the search direction
     * @param f               the function being minimized
     * @param fp              not used by this implementation
     * @param f_x             the value of f(x_k), or NaN to have it computed here
     * @param gradP           the value of x_grad<sup>T</sup>p_k, or NaN to have
     *                        it computed here
     * @param x_alpha_pk      storage that receives x_k + &alpha; p_k; may be
     *                        {@code null}, in which case a clone of {@code x_k}
     *                        is used internally
     * @param fxApRet         if not {@code null}, index 0 receives the function
     *                        value at the last evaluated trial point
     * @param grad_x_alpha_pk not used by this implementation
     * @param parallel        passed through to the function evaluations
     * @return the accepted step length, or the last evaluated step length if the
     *         next trial step would fall below 1e-20
     */
    @Override
    public double lineSearch(double alpha_max, Vec x_k, Vec x_grad, Vec p_k, Function f, FunctionVec fp, double f_x, double gradP, Vec x_alpha_pk, double[] fxApRet, Vec grad_x_alpha_pk, boolean parallel) {
        if (Double.isNaN(f_x))
            f_x = f.f(x_k, parallel);
        if (Double.isNaN(gradP))
            gradP = x_grad.dot(p_k);

        double alpha = alpha_max;
        if (x_alpha_pk == null)
            x_alpha_pk = x_k.clone();
        else
            x_k.copyTo(x_alpha_pk);
        x_alpha_pk.mutableAdd(alpha, p_k);
        double f_xap = f.f(x_alpha_pk, parallel);
        if (fxApRet != null)
            fxApRet[0] = f_xap;
        double oldAlpha = 0;
        double oldF_xap = f_x;
        // Explicit flag instead of comparing alpha == alpha_max on doubles
        boolean firstBacktrack = true;

        while (f_xap > f_x + c1 * alpha * gradP)// we return start if its already good
        {
            // see INTERPOLATION section of chapter 3.5
            final double alphaCandidate;
            if (firstBacktrack)// quadratic interpolation
            {
                // Minimizer of the quadratic through φ(0), φ'(0) and φ(α). This must
                // use the current trial step: the previous step is still 0 here and
                // would always produce a candidate of 0 that the safeguard rejects.
                alphaCandidate = -gradP * alpha * alpha / (2 * (f_xap - f_x - gradP * alpha));
                firstBacktrack = false;
            } else// cubic interpolation
            {
                // g = φ(α1)−φ(0)−φ'(0)α1
                double g = f_xap - f_x - gradP * alpha;
                // h = φ(α0) − φ(0) − φ'(0)α0
                double h = oldF_xap - f_x - gradP * oldAlpha;

                double a0Sqrd = oldAlpha * oldAlpha;
                double a1Sqrd = alpha * alpha;
                double denom = a0Sqrd * a1Sqrd * (alpha - oldAlpha);

                double a = (a0Sqrd * g - a1Sqrd * h) / denom;
                double b = (-a0Sqrd * oldAlpha * g + a1Sqrd * alpha * h) / denom;

                // A negative discriminant or a == 0 yields NaN / infinity, which the
                // safeguard below turns into a plain rho shrink.
                alphaCandidate = (-b + Math.sqrt(b * b - 3 * a * gradP)) / (3 * a);
            }

            oldAlpha = alpha;
            alpha = safeguardedStep(alphaCandidate, oldAlpha);

            if (alpha < MIN_ALPHA)
                return oldAlpha;// x_alpha_pk and fxApRet still correspond to oldAlpha
            // move from x + oldAlpha p to x + alpha p without recopying x_k
            x_alpha_pk.mutableSubtract(oldAlpha - alpha, p_k);
            oldF_xap = f_xap;
            f_xap = f.f(x_alpha_pk, parallel);
            if (fxApRet != null)
                fxApRet[0] = f_xap;
        }

        return alpha;
    }

    /**
     * Accepts an interpolated step only if it lies in [0.1, 0.9] times the
     * previous step; otherwise falls back to shrinking the previous step by rho.
     * 
     * @param candidate the interpolated step length, possibly NaN or infinite
     * @param previous  the step length that just failed the decrease condition
     * @return the next step length to try
     */
    private double safeguardedStep(double candidate, double previous) {
        final double tooSmall = 0.1 * previous;
        final double tooLarge = 0.9 * previous;
        if (Double.isNaN(candidate) || candidate < tooSmall || candidate > tooLarge)
            return rho * previous;
        return candidate;
    }

    @Override
    public boolean updatesGrad() {
        return false;
    }

    @Override
    public BacktrackingArmijoLineSearch clone() {
        return new BacktrackingArmijoLineSearch(rho, c1);
    }
}
